/*
 * Copyright (c) 2021 Oracle and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.helidon.common.crypto;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.security.SecureRandom;
import java.security.spec.AlgorithmParameterSpec;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.spec.ChaCha20ParameterSpec;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import io.helidon.common.Base64Value;
import io.helidon.common.LazyValue;

/**
 * This class provides a simple and stateless way to encrypt and decrypt messages using a selected symmetric cipher.
 * <br>
 * It requires a base password to be provided. A unique cryptographic key is derived from that password every time
 * this implementation encrypts or decrypts. Key uniqueness is ensured by a randomly generated salt which is combined
 * with the password.
 * <br>
 * Values produced by {@link #encrypt(Base64Value)} have the following binary layout (before Base64 encoding):
 * {@code salt (16 bytes) | IV length (4-byte big-endian int) | IV | cipher text}.
 */
public class SymmetricCipher implements CommonCipher {

    /**
     * AES algorithm with CBC method and PKCS5 padding.
     * <br>
     * It is strongly advised to be used together with HMAC or other authenticated message digest.
     * <br>
     * Value is: {@value}.
     */
    public static final String ALGORITHM_AES_CBC = "AES/CBC/PKCS5Padding";

    /**
     * AES algorithm with CTR method and no padding.
     * <br>
     * It is strongly advised to be used together with HMAC or other authenticated message digest.
     * <br>
     * Value is: {@value}.
     */
    public static final String ALGORITHM_AES_CTR = "AES/CTR/NoPadding";

    /**
     * AES algorithm with GCM method and no padding.
     * <br>
     * Does not need to be combined with any authenticated message digest, since GCM provides message
     * authentication by design.
     * <br>
     * Value is: {@value}.
     */
    public static final String ALGORITHM_AES_GCM = "AES/GCM/NoPadding";

    /**
     * ChaCha20 encryption algorithm.
     * <br>
     * Value is: {@value}.
     */
    public static final String ALGORITHM_CHA_CHA = "ChaCha20";

    /**
     * ChaCha20 encryption algorithm with Poly1305 authentication code.
     * <br>
     * Value is: {@value}.
     */
    public static final String ALGORITHM_CHA_CHA_POLY1305 = "ChaCha20-Poly1305";

    private static final LazyValue<SecureRandom> SECURE_RANDOM = LazyValue.create(SecureRandom::new);
    // Extracts the key algorithm (e.g. "AES") from a full transformation such as "AES/GCM/NoPadding"
    private static final Pattern PATTERN_ALGORITHM = Pattern.compile("^(\\S+)/\\S+/\\S+$");
    private static final int SALT_LENGTH = 16;
    // Authentication tag length (in bits) used when building GCM parameters
    private static final int GCM_TAG_LENGTH_BITS = 128;
    // Initial block counter used when building ChaCha20 parameters
    private static final int CHA_CHA_INITIAL_COUNTER = 1;

    private final String algorithm;
    private final String provider;
    private final char[] password;
    private final int keySize;
    private final int numberOfIterations;

    private SymmetricCipher(Builder builder) {
        this.algorithm = builder.algorithm;
        this.provider = builder.provider;
        this.password = builder.password;
        this.keySize = builder.keySize;
        this.numberOfIterations = builder.numberOfIterations;
    }

    /**
     * Create a new builder.
     *
     * @return new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Create a new instance based on the password.
     *
     * Default used algorithm is {@link #ALGORITHM_AES_GCM}.
     *
     * @param password password
     * @return new instance
     */
    public static SymmetricCipher create(char[] password) {
        return new Builder().password(password).build();
    }

    /**
     * Encrypt the message with the usage of provided parameters.
     *
     * @param algorithm algorithm to be used
     * @param key       key to encrypt message
     * @param iv        initialization vector
     * @param plain     message
     * @return encrypted message
     */
    public static Base64Value encrypt(String algorithm, byte[] key, byte[] iv, Base64Value plain) {
        return encrypt(algorithm, null, key, iv, plain);
    }

    /**
     * Encrypt the message with the usage of provided parameters.
     *
     * @param algorithm algorithm to be used
     * @param provider  provider of the algorithm, or {@code null} to use the default provider lookup
     * @param key       key to encrypt message
     * @param iv        initialization vector
     * @param plain     message
     * @return encrypted message
     */
    public static Base64Value encrypt(String algorithm, String provider, byte[] key, byte[] iv, Base64Value plain) {
        Objects.requireNonNull(algorithm, "Algorithm cannot be null");
        Objects.requireNonNull(iv, "Initialization vector cannot be null");
        return encrypt(algorithm, provider, key, createAlgorithmParameter(algorithm, iv), plain);
    }

    /**
     * Encrypt the message with the usage of provided parameters.
     *
     * @param algorithm algorithm to be used
     * @param provider  provider of the algorithm, or {@code null} to use the default provider lookup
     * @param key       key to encrypt message
     * @param params    cipher parameter object
     * @param plain     message
     * @return encrypted message
     * @throws CryptoException if the cipher cannot be prepared or the encryption fails
     */
    public static Base64Value encrypt(String algorithm,
                                      String provider,
                                      byte[] key,
                                      AlgorithmParameterSpec params,
                                      Base64Value plain) {
        Objects.requireNonNull(algorithm, "Algorithm cannot be null");
        Objects.requireNonNull(key, "Key cannot be null");
        Objects.requireNonNull(params, "Algorithm parameters cannot be null");
        Objects.requireNonNull(plain, "Plain content cannot be null");
        Cipher cipher = cipher(algorithm, provider, key, params, Cipher.ENCRYPT_MODE);
        try {
            return Base64Value.create(cipher.doFinal(plain.toBytes()));
        } catch (IllegalBlockSizeException | BadPaddingException e) {
            throw new CryptoException("Failed to encrypt the message", e);
        }
    }

    /**
     * Decrypt the message with the usage of provided parameters.
     *
     * @param algorithm algorithm to be used
     * @param key       key to decrypt message
     * @param iv        encrypted message initialization vector
     * @param encrypted encrypted message
     * @return decrypted message
     */
    public static Base64Value decrypt(String algorithm, byte[] key, byte[] iv, Base64Value encrypted) {
        return decrypt(algorithm, null, key, iv, encrypted);
    }

    /**
     * Decrypt the message with the usage of provided parameters.
     *
     * @param algorithm algorithm to be used
     * @param provider  algorithm provider, or {@code null} to use the default provider lookup
     * @param key       key to decrypt message
     * @param iv        encrypted message initialization vector
     * @param encrypted encrypted message
     * @return decrypted message
     */
    public static Base64Value decrypt(String algorithm, String provider, byte[] key, byte[] iv, Base64Value encrypted) {
        Objects.requireNonNull(algorithm, "Algorithm cannot be null");
        Objects.requireNonNull(iv, "Initialization vector cannot be null");
        return decrypt(algorithm, provider, key, createAlgorithmParameter(algorithm, iv), encrypted);
    }

    /**
     * Decrypt the message with the usage of provided parameters.
     *
     * @param algorithm algorithm to be used
     * @param provider  algorithm provider, or {@code null} to use the default provider lookup
     * @param key       key to decrypt message
     * @param params    cipher parameter object; may be {@code null}, in which case the cipher is initialized
     *                  with the key only
     * @param encrypted encrypted message
     * @return decrypted message
     * @throws CryptoException if the cipher cannot be prepared or the decryption fails
     */
    public static Base64Value decrypt(String algorithm,
                                      String provider,
                                      byte[] key,
                                      AlgorithmParameterSpec params,
                                      Base64Value encrypted) {
        Objects.requireNonNull(algorithm, "Algorithm cannot be null");
        Objects.requireNonNull(key, "Key cannot be null");
        Objects.requireNonNull(encrypted, "Encrypted content cannot be null");
        Cipher cipher = cipher(algorithm, provider, key, params, Cipher.DECRYPT_MODE);
        try {
            return Base64Value.create(cipher.doFinal(encrypted.toBytes()));
        } catch (IllegalBlockSizeException | BadPaddingException e) {
            throw new CryptoException("Failed to decrypt the message", e);
        }
    }

    /**
     * Create and initialize a {@link Cipher} for the given transformation, key, and mode.
     * The key algorithm for {@link SecretKeySpec} is the first segment of a "alg/mode/padding" transformation,
     * or the whole algorithm string otherwise.
     *
     * @param algorithm     cipher transformation
     * @param provider      provider name, or {@code null} for the default provider lookup
     * @param key           raw key bytes
     * @param parameterSpec algorithm parameters, or {@code null} to initialize with the key only
     * @param cipherMode    {@link Cipher#ENCRYPT_MODE} or {@link Cipher#DECRYPT_MODE}
     * @return initialized cipher
     * @throws CryptoException wrapping any failure while obtaining or initializing the cipher
     */
    private static Cipher cipher(String algorithm,
                                 String provider,
                                 byte[] key,
                                 AlgorithmParameterSpec parameterSpec,
                                 int cipherMode) {
        try {
            Matcher matcher = PATTERN_ALGORITHM.matcher(algorithm);
            String keySpecAlg = matcher.matches() ? matcher.group(1) : algorithm;
            SecretKeySpec spec = new SecretKeySpec(key, keySpecAlg);
            Cipher cipher;
            if (provider == null) {
                cipher = Cipher.getInstance(algorithm);
            } else {
                cipher = Cipher.getInstance(algorithm, provider);
            }
            if (parameterSpec == null) {
                cipher.init(cipherMode, spec);
            } else {
                cipher.init(cipherMode, spec, parameterSpec);
            }
            return cipher;
        } catch (Exception e) {
            throw new CryptoException("Failed to prepare a cipher instance", e);
        }
    }

    /**
     * Build the algorithm-specific parameter object carrying the given IV/nonce.
     *
     * @param algorithm cipher transformation
     * @param iv        initialization vector or nonce
     * @return GCM, ChaCha20, or plain IV parameter spec depending on the algorithm
     */
    private static AlgorithmParameterSpec createAlgorithmParameter(String algorithm, byte[] iv) {
        switch (algorithm) {
        case ALGORITHM_AES_GCM:
            return new GCMParameterSpec(GCM_TAG_LENGTH_BITS, iv);
        case ALGORITHM_CHA_CHA:
            return new ChaCha20ParameterSpec(iv, CHA_CHA_INITIAL_COUNTER);
        default:
            return new IvParameterSpec(iv);
        }
    }

    @Override
    public Base64Value encrypt(Base64Value message) {
        Objects.requireNonNull(message, "Plain content cannot be null");
        try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
                DataOutputStream dataOutputStream = new DataOutputStream(outputStream)) {
            byte[] salt = new byte[SALT_LENGTH];
            SECURE_RANDOM.get().nextBytes(salt);
            byte[] key = PasswordKeyDerivation.deriveKey(password, salt, numberOfIterations, keySize);
            // No parameters are passed; per the Cipher.init contract the provider generates any
            // required IV/nonce itself when initializing for encryption.
            Cipher cipher = cipher(algorithm, provider, key, null, Cipher.ENCRYPT_MODE);
            byte[] iv = cipher.getIV();
            if (iv == null) {
                // The algorithm uses no IV; store an empty one instead of failing with a NullPointerException.
                iv = new byte[0];
            }
            outputStream.writeBytes(salt);
            dataOutputStream.writeInt(iv.length);
            outputStream.writeBytes(iv);
            outputStream.writeBytes(cipher.doFinal(message.toBytes()));
            return Base64Value.create(outputStream.toByteArray());
        } catch (IOException | IllegalBlockSizeException | BadPaddingException e) {
            throw new CryptoException("An error occurred while message encryption", e);
        }
    }

    @Override
    public Base64Value decrypt(Base64Value encrypted) {
        Objects.requireNonNull(encrypted, "Encrypted content cannot be null");
        try (ByteArrayInputStream inputStream = new ByteArrayInputStream(encrypted.toBytes());
                DataInputStream dataInputStream = new DataInputStream(inputStream)) {
            byte[] salt = inputStream.readNBytes(SALT_LENGTH);
            if (salt.length != SALT_LENGTH) {
                throw new CryptoException("Encrypted value is not valid");
            }
            int ivSize = dataInputStream.readInt();
            // The IV length comes from untrusted input. ByteArrayInputStream.available() is the exact number of
            // remaining bytes, so this rejects negative or oversized lengths, which would otherwise surface as
            // IllegalArgumentException or a silently truncated IV.
            if (ivSize < 0 || ivSize > inputStream.available()) {
                throw new CryptoException("Encrypted value is not valid");
            }
            byte[] iv = inputStream.readNBytes(ivSize);
            byte[] key = PasswordKeyDerivation.deriveKey(password, salt, numberOfIterations, keySize);
            // An empty IV is what encrypt() stores for algorithms that do not use one.
            AlgorithmParameterSpec params = ivSize == 0 ? null : createAlgorithmParameter(algorithm, iv);
            return decrypt(algorithm, provider, key, params, Base64Value.create(inputStream.readAllBytes()));
        } catch (EOFException e) {
            throw new CryptoException("Encrypted value is not valid", e);
        } catch (IOException e) {
            throw new CryptoException("An error occurred while message decryption", e);
        }
    }

    /**
     * Builder of the {@link SymmetricCipher}.
     */
    public static class Builder implements io.helidon.common.Builder<Builder, SymmetricCipher> {

        private String algorithm = ALGORITHM_AES_GCM;
        private String provider = null;
        private int numberOfIterations = 10000;
        private int keySize = 256;
        private char[] password;

        private Builder() {
        }

        /**
         * Set algorithm which should be used.
         * <br>
         * Default value is {@link #ALGORITHM_AES_GCM}.
         *
         * @param algorithm algorithm to be used
         * @return updated builder instance
         */
        public Builder algorithm(String algorithm) {
            this.algorithm = Objects.requireNonNull(algorithm, "Algorithm cannot be null");
            return this;
        }

        /**
         * Set provider of the algorithm.
         * <br>
         * If not set (or set to {@code null}), the cipher is obtained via {@link Cipher#getInstance(String)}.
         *
         * @param provider provider to be used
         * @return updated builder instance
         */
        public Builder provider(String provider) {
            this.provider = provider;
            return this;
        }

        /**
         * Set password upon which the cryptography key will be generated.
         * <br>
         * The array is copied, so later changes to the passed array do not affect this builder.
         *
         * @param password base password
         * @return updated builder instance
         */
        public Builder password(char[] password) {
            Objects.requireNonNull(password, "Password cannot be null");
            this.password = password.clone();
            return this;
        }

        /**
         * Set size of the key (in bits) which should be generated.
         *
         * Default value is 256 bit.
         *
         * @param keySize size of the key
         * @return updated builder instance
         */
        public Builder keySize(int keySize) {
            this.keySize = keySize;
            return this;
        }

        /**
         * Number of iterations which will be used for key derivation from the password.
         *
         * Default value is 10000.
         *
         * @param numberOfIterations number of iterations
         * @return updated builder instance
         */
        public Builder numberOfIterations(int numberOfIterations) {
            this.numberOfIterations = numberOfIterations;
            return this;
        }

        /**
         * Build a new {@link SymmetricCipher} from the configured values.
         *
         * @return new cipher instance
         * @throws CryptoException if no password has been set
         */
        @Override
        public SymmetricCipher build() {
            if (password == null) {
                throw new CryptoException("Password has to be specified.");
            }
            return new SymmetricCipher(this);
        }
    }

}
